# R script for estimating state-level excess deaths
# Step II: Applying the migration-adjusted method to full-count and linked census data

# START SCRIPT

# Load libraries ----------------------------------------------------------

# Install any missing dependency, then attach it. `require()` alone is not
# enough here: when it returns FALSE, `install.packages()` installs the
# package but does not attach it, so unqualified calls later in the script
# (e.g. `full_join()`) would fail on a fresh machine. The caller's workspace
# is left untouched; every object used below is created by this script.
for (pkg in c("tidyverse", "readxl", "writexl")) {
  if (!requireNamespace(pkg, quietly = TRUE)) install.packages(pkg)
  library(pkg, character.only = TRUE)
}

# Read cohort size by sex and state ---------------------------------------

census_years <- seq(1850, 1880, 10)

## Read native-born white male cohort sizes by state ----

# One worksheet per census year; the list is keyed by the year as a string.
ncohort_males <- setNames(
  lapply(as.character(census_years), \(sheet_name) {
    readxl::read_xlsx("DAT_Ncohort_nbwm_1850to1880_byState.xlsx", sheet = sheet_name)
  }),
  census_years
)

ncohort_male_1850 <- ncohort_males[["1850"]]
ncohort_male_1860 <- ncohort_males[["1860"]]
ncohort_male_1870 <- ncohort_males[["1870"]]
ncohort_male_1880 <- ncohort_males[["1880"]]

## Read native-born white female cohort sizes by state ----

# Same layout as the male workbook: one worksheet per census year.
ncohort_females <- setNames(
  lapply(as.character(census_years), \(sheet_name) {
    readxl::read_xlsx("DAT_Ncohort_nbwf_1850to1880_byState.xlsx", sheet = sheet_name)
  }),
  census_years
)

ncohort_female_1850 <- ncohort_females[["1850"]]
ncohort_female_1860 <- ncohort_females[["1860"]]
ncohort_female_1870 <- ncohort_females[["1870"]]
ncohort_female_1880 <- ncohort_females[["1880"]]

# Compute survival rate by sex and state ----------------------------------

# Cohort survival ratios between two censuses, by state.
#
# nc_t1: cohort sizes at the earlier census; column 1 is STATE, columns 2:5
#        are the first four age groups.
# nc_t2: cohort sizes at the later census, same layout; columns 3:6 (one age
#        group older) are used as the numerator.
# Returns a data.frame with STATE (alphabetical, states present in both
# censuses) and the ratios labelled by the age group at the earlier census.
compute_psurvival_by_sexstate <- function(nc_t1, nc_t2) {
  # States that only appear in the later census (e.g. Kansas and Nevada,
  # absent from 1850) are dropped silently. States that vanish from the later
  # census are dropped with a warning instead of making the division fail on
  # unequal row counts.
  vanished <- setdiff(nc_t1[["STATE"]], nc_t2[["STATE"]])
  if (length(vanished) > 0) {
    warning("States missing from the later census were dropped: ",
            paste(vanished, collapse = ", "), call. = FALSE)
  }

  # Align both tables row by row on the shared states, sorted alphabetically.
  states <- sort(intersect(nc_t1[["STATE"]], nc_t2[["STATE"]]))
  nc_t1 <- nc_t1[match(states, nc_t1[["STATE"]]), ]
  nc_t2 <- nc_t2[match(states, nc_t2[["STATE"]]), ]

  # Shifting by one age column between censuses pairs each age group with the
  # next-older group at the later census.
  rs_t1t2 <- nc_t2[, 3:6] / nc_t1[, 2:5]
  names(rs_t1t2) <- c("5-14", "15-24", "25-34", "35-44")
  cbind.data.frame(STATE = states, rs_t1t2)
}

## Compute male survival rates by state ---- 
# Each period pairs a census with the following one.
rsurvival_male_1850to60 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_male_1850, nc_t2 = ncohort_male_1860)
rsurvival_male_1860to70 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_male_1860, nc_t2 = ncohort_male_1870)
rsurvival_male_1870to80 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_male_1870, nc_t2 = ncohort_male_1880)

rsurvival_males <- setNames(
  list(rsurvival_male_1850to60, rsurvival_male_1860to70, rsurvival_male_1870to80),
  c("1850-60", "1860-70", "1870-80")
)

## Compute female survival rates by state ----
rsurvival_female_1850to60 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_female_1850, nc_t2 = ncohort_female_1860)
rsurvival_female_1860to70 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_female_1860, nc_t2 = ncohort_female_1870)
rsurvival_female_1870to80 <- compute_psurvival_by_sexstate(nc_t1 = ncohort_female_1870, nc_t2 = ncohort_female_1880)

rsurvival_females <- setNames(
  list(rsurvival_female_1850to60, rsurvival_female_1860to70, rsurvival_female_1870to80),
  c("1850-60", "1860-70", "1870-80")
)

# Add state groups to estimates -------------------------------------------

state_group <- readxl::read_xlsx("DICT_state_battleside.xlsx")

# Join the state -> GROUP dictionary onto `tab` by STATE, move GROUP next to
# STATE, turn it into a factor with levels in the order Union, Border,
# Confederate, and sort rows by group and then by state.
add_state_groups <- function(tab, state_group_df = state_group) {
  joined <- dplyr::left_join(tab, state_group_df, by = "STATE")
  joined <- dplyr::relocate(joined, GROUP, .after = STATE)
  joined <- dplyr::mutate(
    joined,
    GROUP = factor(GROUP, levels = c("Union", "Border", "Confederate"))
  )
  dplyr::arrange(joined, GROUP, STATE)
}

## Group cohort sizes ----
ncohort_males <- lapply(X = ncohort_males, FUN = add_state_groups)
ncohort_females <- lapply(X = ncohort_females, FUN = add_state_groups)

## Group survival rates ----
rsurvival_males <- lapply(X = rsurvival_males, FUN = add_state_groups)
rsurvival_females <- lapply(X = rsurvival_females, FUN = add_state_groups)

# Select key cohort size and survival rate estimates ----------------------

## Select 1860 male cohort sizes ----

# Dropping 45-54 leaves the same number of age columns as the four
# survival-rate columns, which are later multiplied column by column.
ntotal_male_1860 <- dplyr::select(ncohort_males[["1860"]], -`45-54`)

## Select survival rate estimates ----

rsurvival_male_1850to60 <- rsurvival_males[["1850-60"]]
rsurvival_male_1860to70 <- rsurvival_males[["1860-70"]]
rsurvival_male_1870to80 <- rsurvival_males[["1870-80"]]

# Set state groups --------------------------------------------------------

# Union of the (STATE, GROUP) pairs across the three male survival periods,
# ordered Union -> Border -> Confederate and alphabetically within a group.
# The join is namespaced and keyed explicitly, so it neither depends on dplyr
# being attached nor prints a "Joining with `by = ...`" message.
state_groups <- list(rsurvival_male_1850to60, rsurvival_male_1860to70, rsurvival_male_1870to80) |>
  lapply(\(df) df[, c("STATE", "GROUP")]) |>
  Reduce(f = \(x, y) dplyr::full_join(x, y, by = c("STATE", "GROUP"))) |>
  dplyr::mutate(GROUP = factor(GROUP, levels = c("Union", "Border", "Confederate"))) |>
  dplyr::arrange(GROUP, STATE)

state_vec <- state_groups$STATE

# Read linked cohort size (first-period) ----------------------------------

census_periods <- c("1850-60", "1860-70", "1870-80")

# Linked first-period cohort sizes: one worksheet per inter-census period.
nrecords_males <- setNames(
  lapply(census_periods, \(sheet_name) {
    readxl::read_xlsx("DAT_Ncohort_nbwm_1850to1880_byState_Linked.xlsx", sheet = sheet_name)
  }),
  census_periods
)

nrecords_male_1850to60 <- nrecords_males[["1850-60"]]
nrecords_male_1860to70 <- nrecords_males[["1860-70"]]
nrecords_male_1870to80 <- nrecords_males[["1870-80"]]

# Compute out-migration size ----------------------------------------------

## Read count of out-migration ----

# Out-migration counts, one worksheet per inter-census period.
nout_males <- setNames(
  lapply(census_periods, \(sheet_name) {
    readxl::read_xlsx("DAT_Noutmig_nbwm_1850to1880_byState_Linked.xlsx", sheet = sheet_name)
  }),
  census_periods
)

nout_male_1850to60 <- nout_males[["1850-60"]]
nout_male_1860to70 <- nout_males[["1860-70"]]
nout_male_1870to80 <- nout_males[["1870-80"]]

## Compute rate of out-migration ----

# Out-migration rate by state and age group.
#
# nout: out-migration counts with a STATE column and the age columns
#       "5-14", "15-24", "25-34", "35-44".
# ntot: linked first-period cohort sizes with the same columns; states not
#       in `nout` are ignored.
# Returns a data.frame with STATE (alphabetical) and nout / ntot per age group.
compute_rout_migration <- function(nout, ntot) {
  age_groups <- c("5-14", "15-24", "25-34", "35-44")

  # A state without a denominator cannot get a rate. Drop it with a warning
  # instead of letting the division fail on unequal row counts.
  no_denominator <- setdiff(nout[["STATE"]], ntot[["STATE"]])
  if (length(no_denominator) > 0) {
    warning("States without a denominator in `ntot` were dropped: ",
            paste(no_denominator, collapse = ", "), call. = FALSE)
  }

  # Align numerator and denominator on the same state in every row.
  states <- sort(intersect(nout[["STATE"]], ntot[["STATE"]]))
  nout_aligned <- nout[match(states, nout[["STATE"]]), age_groups]
  ntot_aligned <- ntot[match(states, ntot[["STATE"]]), age_groups]

  rout <- nout_aligned / ntot_aligned
  names(rout) <- age_groups
  cbind.data.frame(STATE = states, rout)
}

rout_male_1850to60 <- compute_rout_migration(nout = nout_male_1850to60, ntot = nrecords_male_1850to60)
rout_male_1860to70 <- compute_rout_migration(nout = nout_male_1860to70, ntot = nrecords_male_1860to70)
rout_male_1870to80 <- compute_rout_migration(nout = nout_male_1870to80, ntot = nrecords_male_1870to80)

rout_males <- setNames(
  list(rout_male_1850to60, rout_male_1860to70, rout_male_1870to80),
  c("1850-60", "1860-70", "1870-80")
)

# Compute in-migration size -----------------------------------------------

## Read count of in-migration ----

# In-migration counts, one worksheet per inter-census period.
nin_males <- setNames(
  lapply(census_periods, \(sheet_name) {
    readxl::read_xlsx("DAT_Ninmig_nbwm_1850to1880_byState_Linked.xlsx", sheet = sheet_name)
  }),
  census_periods
)

nin_male_1850to60 <- nin_males[["1850-60"]]
nin_male_1860to70 <- nin_males[["1860-70"]]
nin_male_1870to80 <- nin_males[["1870-80"]]

## Compute rate of in-migration ----

# In-migration rate by state and age group.
#
# nin:  in-migration counts with a STATE column and the age columns
#       "5-14", "15-24", "25-34", "35-44".
# ntot: linked first-period cohort sizes with the same columns; states not
#       in `nin` are ignored.
# Returns a data.frame with STATE (alphabetical) and nin / ntot per age group.
compute_rin_migration <- function(nin, ntot) {
  age_groups <- c("5-14", "15-24", "25-34", "35-44")

  # A state without a denominator cannot get a rate. Drop it with a warning
  # instead of letting the division fail on unequal row counts.
  no_denominator <- setdiff(nin[["STATE"]], ntot[["STATE"]])
  if (length(no_denominator) > 0) {
    warning("States without a denominator in `ntot` were dropped: ",
            paste(no_denominator, collapse = ", "), call. = FALSE)
  }

  # Align numerator and denominator on the same state in every row.
  states <- sort(intersect(nin[["STATE"]], ntot[["STATE"]]))
  nin_aligned <- nin[match(states, nin[["STATE"]]), age_groups]
  ntot_aligned <- ntot[match(states, ntot[["STATE"]]), age_groups]

  rin <- nin_aligned / ntot_aligned
  names(rin) <- age_groups
  cbind.data.frame(STATE = states, rin)
}

# Kansas and Nevada are not available in the 1850 census, so they are removed
# from the 1850-60 in-migration counts before computing the rates.
rin_male_1850to60 <- compute_rin_migration(
  nin = dplyr::filter(nin_male_1850to60, !STATE %in% c("Kansas", "Nevada")),
  ntot = nrecords_male_1850to60
)
rin_male_1860to70 <- compute_rin_migration(nin = nin_male_1860to70, ntot = nrecords_male_1860to70)
rin_male_1870to80 <- compute_rin_migration(nin = nin_male_1870to80, ntot = nrecords_male_1870to80)

rin_males <- setNames(
  list(rin_male_1850to60, rin_male_1860to70, rin_male_1870to80),
  c("1850-60", "1860-70", "1870-80")
)

# Compute net out-migration rate ------------------------------------------

## Note: Only the net rate is needed here, not the net count

# Net out-migration rate (out-migration rate minus in-migration rate).
#
# rout, rin: data.frames with a STATE column and the age columns
#            "5-14", "15-24", "25-34", "35-44"; both must cover exactly the
#            same states (row order does not matter).
# Returns a data.frame with STATE (alphabetical) and rout - rin per age group.
# Errors with "State orders do not match!" if the state sets differ.
compute_rnetout_migration <- function(rout, rin) {
  age_groups <- c("5-14", "15-24", "25-34", "35-44")
  rout <- rout[order(rout[["STATE"]]), ]
  rin <- rin[order(rin[["STATE"]]), ]

  # all.equal() returns a character description (not FALSE) on mismatch, so
  # it must be wrapped in isTRUE() before negating.
  if (!isTRUE(all.equal(rout[["STATE"]], rin[["STATE"]]))) {
    stop("State orders do not match!", call. = FALSE)
  }

  rnetout <- rout[, age_groups] - rin[, age_groups]
  names(rnetout) <- age_groups
  cbind.data.frame(STATE = rout[["STATE"]], rnetout)
}

rnetout_male_1850to60 <- compute_rnetout_migration(rout = rout_male_1850to60, rin = rin_male_1850to60)
rnetout_male_1860to70 <- compute_rnetout_migration(rout = rout_male_1860to70, rin = rin_male_1860to70)
rnetout_male_1870to80 <- compute_rnetout_migration(rout = rout_male_1870to80, rin = rin_male_1870to80)

rnetout_males <- setNames(
  list(rnetout_male_1850to60, rnetout_male_1860to70, rnetout_male_1870to80),
  c("1850-60", "1860-70", "1870-80")
)

# Compute baseline net out-migration rate ---------------------------------

# Baselines for 1860-70: the 1850-60 rates, the 1870-80 rates, or their
# average. Kansas and Nevada (not available in the 1850 census) are removed
# from 1870-80 before averaging. The average is taken row by row, so the two
# tables must then list the same states in the same order. A bare
# all.equal() at script level is discarded, so the check is enforced here.
stopifnot(isTRUE(all.equal(
  rnetout_male_1850to60$STATE,
  rnetout_male_1870to80$STATE[!rnetout_male_1870to80$STATE %in% c("Kansas", "Nevada")]
)))
ernetout_male_1860to70_basepre <- rnetout_male_1850to60
ernetout_male_1860to70_basepost <- rnetout_male_1870to80

ernetout_male_1860to70_baseavg <- (
  rnetout_male_1850to60[, c("5-14", "15-24", "25-34", "35-44")] +
    rnetout_male_1870to80[!rnetout_male_1870to80$STATE %in% c("Kansas", "Nevada"), c("5-14", "15-24", "25-34", "35-44")]
) / 2
names(ernetout_male_1860to70_baseavg) <- c("5-14", "15-24", "25-34", "35-44")
ernetout_male_1860to70_baseavg <- cbind.data.frame(STATE = rnetout_male_1850to60$STATE, ernetout_male_1860to70_baseavg)

ernetout_males <- list(
  "1850-60 Baseline" = ernetout_male_1860to70_basepre,
  "1870-80 Baseline" = ernetout_male_1860to70_basepost,
  "Average Baseline" = ernetout_male_1860to70_baseavg
)

# Compute excess net out-migration rate -----------------------------------

# Excess net out-migration = observed 1860-70 rate - baseline rate. Kansas and
# Nevada are dropped from 1860-70 for the baselines built on 1850-60 data.
# The subtraction is positional, so each step first asserts that both sides
# list the same states in the same order.
stopifnot(isTRUE(all.equal(
  rnetout_male_1860to70$STATE[!rnetout_male_1860to70$STATE %in% c("Kansas", "Nevada")],
  ernetout_male_1860to70_basepre$STATE
)))
excrnetout_male_1860to70_basepre <- rnetout_male_1860to70[!rnetout_male_1860to70$STATE %in% c("Kansas", "Nevada"), -1] - ernetout_male_1860to70_basepre[, -1]
excrnetout_male_1860to70_basepre <- cbind.data.frame(STATE = ernetout_male_1860to70_basepre$STATE, excrnetout_male_1860to70_basepre)

stopifnot(isTRUE(all.equal(
  rnetout_male_1860to70$STATE,
  ernetout_male_1860to70_basepost$STATE
)))
excrnetout_male_1860to70_basepost <- rnetout_male_1860to70[, -1] - ernetout_male_1860to70_basepost[, -1]
excrnetout_male_1860to70_basepost <- cbind.data.frame(STATE = ernetout_male_1860to70_basepost$STATE, excrnetout_male_1860to70_basepost)

stopifnot(isTRUE(all.equal(
  rnetout_male_1860to70$STATE[!rnetout_male_1860to70$STATE %in% c("Kansas", "Nevada")],
  ernetout_male_1860to70_baseavg$STATE
)))
excrnetout_male_1860to70_baseavg <- rnetout_male_1860to70[!rnetout_male_1860to70$STATE %in% c("Kansas", "Nevada"), -1] - ernetout_male_1860to70_baseavg[, -1]
excrnetout_male_1860to70_baseavg <- cbind.data.frame(STATE = ernetout_male_1860to70_baseavg$STATE, excrnetout_male_1860to70_baseavg)

excrnetout_males <- list(
  "1850-60 Baseline" = excrnetout_male_1860to70_basepre,
  "1870-80 Baseline" = excrnetout_male_1860to70_basepost,
  "Average Baseline" = excrnetout_male_1860to70_baseavg
)

# Compute baseline raw mortality rate -------------------------------------

# Baseline survival for 1860-70: the 1850-60 rates, the 1870-80 rates, or
# their average. Both survival tables are sorted by GROUP and then STATE by
# add_state_groups(). After removing Kansas and Nevada (absent from 1850),
# they must list the same states in the same order, because the average is
# taken row by row. The check is enforced instead of being a discarded
# all.equal() call.
stopifnot(isTRUE(all.equal(
  rsurvival_male_1850to60$STATE,
  rsurvival_male_1870to80$STATE[!rsurvival_male_1870to80$STATE %in% c("Kansas", "Nevada")]
)))
ersurvival_male_1860to70_basepre <- rsurvival_male_1850to60
ersurvival_male_1860to70_basepost <- rsurvival_male_1870to80

ersurvival_male_1860to70_baseavg <- (
  rsurvival_male_1850to60[, c("5-14", "15-24", "25-34", "35-44")] +
    rsurvival_male_1870to80[!rsurvival_male_1870to80$STATE %in% c("Kansas", "Nevada"), c("5-14", "15-24", "25-34", "35-44")]
) / 2
names(ersurvival_male_1860to70_baseavg) <- c("5-14", "15-24", "25-34", "35-44")
ersurvival_male_1860to70_baseavg <- cbind.data.frame(
  STATE = ersurvival_male_1860to70_basepre$STATE,
  GROUP = ersurvival_male_1860to70_basepre$GROUP,
  ersurvival_male_1860to70_baseavg
)

ersurvival_males <- list(
  "1850-60 Baseline" = ersurvival_male_1860to70_basepre,
  "1870-80 Baseline" = ersurvival_male_1860to70_basepost,
  "Average Baseline" = ersurvival_male_1860to70_baseavg
)

# Compute excess raw mortality rate ---------------------------------------

## Note: We need the raw rates only, no need for counts

# Excess raw mortality = baseline survival - observed 1860-70 survival.
# Columns 1:2 are STATE and GROUP and are excluded from the arithmetic. Each
# result is re-sorted alphabetically by STATE, the order used by the net
# out-migration tables. The alignment checks are enforced with stopifnot()
# because the subtraction is positional.
stopifnot(isTRUE(all.equal(
  ersurvival_male_1860to70_basepre$STATE,
  rsurvival_male_1860to70$STATE[!rsurvival_male_1860to70$STATE %in% c("Kansas", "Nevada")]
)))
excrrmort_male_1860to70_basepre <- ersurvival_male_1860to70_basepre[, -c(1:2)] - rsurvival_male_1860to70[!rsurvival_male_1860to70$STATE %in% c("Kansas", "Nevada"), -c(1:2)]
excrrmort_male_1860to70_basepre <- cbind.data.frame(STATE = ersurvival_male_1860to70_basepre$STATE, excrrmort_male_1860to70_basepre)
excrrmort_male_1860to70_basepre <- excrrmort_male_1860to70_basepre[order(excrrmort_male_1860to70_basepre$STATE), ]

stopifnot(isTRUE(all.equal(
  ersurvival_male_1860to70_basepost$STATE,
  rsurvival_male_1860to70$STATE
)))
excrrmort_male_1860to70_basepost <- ersurvival_male_1860to70_basepost[, -c(1, 2)] - rsurvival_male_1860to70[, -c(1, 2)]
excrrmort_male_1860to70_basepost <- cbind.data.frame(STATE = ersurvival_male_1860to70_basepost$STATE, excrrmort_male_1860to70_basepost)
excrrmort_male_1860to70_basepost <- excrrmort_male_1860to70_basepost[order(excrrmort_male_1860to70_basepost$STATE), ]

stopifnot(isTRUE(all.equal(
  ersurvival_male_1860to70_baseavg$STATE,
  rsurvival_male_1860to70$STATE[!rsurvival_male_1860to70$STATE %in% c("Kansas", "Nevada")]
)))
excrrmort_male_1860to70_baseavg <- ersurvival_male_1860to70_baseavg[, -c(1, 2)] - rsurvival_male_1860to70[!rsurvival_male_1860to70$STATE %in% c("Kansas", "Nevada"), -c(1:2)]
excrrmort_male_1860to70_baseavg <- cbind.data.frame(STATE = ersurvival_male_1860to70_baseavg$STATE, excrrmort_male_1860to70_baseavg)
excrrmort_male_1860to70_baseavg <- excrrmort_male_1860to70_baseavg[order(excrrmort_male_1860to70_baseavg$STATE), ]

excrrmort_males <- list(
  "1850-60 Baseline" = excrrmort_male_1860to70_basepre,
  "1870-80 Baseline" = excrrmort_male_1860to70_basepost,
  "Average Baseline" = excrrmort_male_1860to70_baseavg
)

# Compute adjusted excess mortality ---------------------------------------

# Combine per-age excess mortality rates with population sizes.
#
# ntot:  population by state; a STATE column, an optional GROUP column, and
#        one numeric column per age group (same count and order as `rmort`).
# rmort: excess mortality rates; a STATE column (optional GROUP) and one
#        column per age group.
# out:   "r" (default) returns the per-age rates plus POP (total population)
#        and TOTAL (all-age excess rate = excess deaths / POP); "n" returns
#        per-age excess death counts plus POP and TOTAL (all-age count).
# Rows are returned in alphabetical STATE order. Errors with
# "State orders do not match!" if the two tables cover different states.
compute_excess_mortality_with_totals <- function(ntot = ntotal_male_1860, rmort, out = c("r", "n")) {
  # Without match.arg() the default `out` has length 2 and `if (out == "n")`
  # errors in R >= 4.2.
  out <- match.arg(out)
  age_groups <- c("5-14", "15-24", "25-34", "35-44")

  ntot <- ntot[order(ntot[["STATE"]]), ]
  rmort <- rmort[order(rmort[["STATE"]]), ]
  # all.equal() returns a character description (not FALSE) on mismatch.
  if (!isTRUE(all.equal(ntot[["STATE"]], rmort[["STATE"]]))) {
    stop("State orders do not match!", call. = FALSE)
  }

  # Use every column except the identifiers, whether or not GROUP is present.
  id_cols <- c("STATE", "GROUP")
  n_c <- ntot[, setdiff(names(ntot), id_cols), drop = FALSE]
  r_mort <- rmort[, setdiff(names(rmort), id_cols), drop = FALSE]

  n_mort <- n_c * r_mort
  names(r_mort) <- age_groups
  names(n_mort) <- age_groups

  n_tot <- rowSums(n_c)
  n_mort_tot <- rowSums(n_mort)

  if (out == "n") {
    cbind.data.frame(STATE = ntot[["STATE"]], n_mort, POP = n_tot, TOTAL = n_mort_tot)
  } else {
    cbind.data.frame(STATE = rmort[["STATE"]], r_mort, POP = n_tot, TOTAL = n_mort_tot / n_tot)
  }
}

## Compute adjusted excess mortality rate ----

# Adjusted excess mortality = excess raw mortality - excess net out-migration.
# Both operands are in alphabetical STATE order and are subtracted
# positionally. The checks are enforced with stopifnot() so a misalignment
# stops the script instead of silently mixing states.
stopifnot(isTRUE(all.equal(
  excrrmort_male_1860to70_basepre$STATE,
  excrnetout_male_1860to70_basepre$STATE
)))
excrmort_male_1860to70_basepre <- excrrmort_male_1860to70_basepre[, -1] - excrnetout_male_1860to70_basepre[, -1]
excrmort_male_1860to70_basepre <- cbind.data.frame(STATE = excrrmort_male_1860to70_basepre$STATE, excrmort_male_1860to70_basepre)

stopifnot(isTRUE(all.equal(
  excrrmort_male_1860to70_basepost$STATE,
  excrnetout_male_1860to70_basepost$STATE
)))
excrmort_male_1860to70_basepost <- excrrmort_male_1860to70_basepost[, -1] - excrnetout_male_1860to70_basepost[, -1]
excrmort_male_1860to70_basepost <- cbind.data.frame(STATE = excrrmort_male_1860to70_basepost$STATE, excrmort_male_1860to70_basepost)

stopifnot(isTRUE(all.equal(
  excrrmort_male_1860to70_baseavg$STATE,
  excrnetout_male_1860to70_baseavg$STATE
)))
excrmort_male_1860to70_baseavg <- excrrmort_male_1860to70_baseavg[, -1] - excrnetout_male_1860to70_baseavg[, -1]
excrmort_male_1860to70_baseavg <- cbind.data.frame(STATE = excrrmort_male_1860to70_baseavg$STATE, excrmort_male_1860to70_baseavg)

# Kansas and Nevada have no 1850-60 baseline (and hence no average baseline),
# so they are removed from the 1860 population for those two estimates; the
# 1870-80 baseline uses every state.
texcrmort_male_1860to70_basepre <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_basepre,
  ntot = dplyr::filter(ntotal_male_1860, !STATE %in% c("Kansas", "Nevada")),
  out = "r"
)
texcrmort_male_1860to70_basepost <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_basepost,
  ntot = ntotal_male_1860,
  out = "r"
)
texcrmort_male_1860to70_baseavg <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_baseavg,
  ntot = dplyr::filter(ntotal_male_1860, !STATE %in% c("Kansas", "Nevada")),
  out = "r"
)

texcrmort_males <- setNames(
  list(texcrmort_male_1860to70_basepre, texcrmort_male_1860to70_basepost, texcrmort_male_1860to70_baseavg),
  c("1850-60 Baseline", "1870-80 Baseline", "Average Baseline")
)

## Compute adjusted excess mortality counts ----

texcnmort_male_1860to70_basepre <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_basepre,
  ntot = dplyr::filter(ntotal_male_1860, !STATE %in% c("Kansas", "Nevada")),
  out = "n"
)
texcnmort_male_1860to70_basepost <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_basepost,
  ntot = ntotal_male_1860,
  out = "n"
)
texcnmort_male_1860to70_baseavg <- compute_excess_mortality_with_totals(
  rmort = excrmort_male_1860to70_baseavg,
  ntot = dplyr::filter(ntotal_male_1860, !STATE %in% c("Kansas", "Nevada")),
  out = "n"
)

texcnmort_males <- setNames(
  list(texcnmort_male_1860to70_basepre, texcnmort_male_1860to70_basepost, texcnmort_male_1860to70_baseavg),
  c("1850-60 Baseline", "1870-80 Baseline", "Average Baseline")
)

# Add state groups to adjusted estimates ----------------------------------

# Attach the Union/Border/Confederate classification to `tab` and sort rows by
# group and then by state. Tables that already carry a GROUP column are not
# joined again; their GROUP is only converted to a factor with levels
# Union, Border, Confederate. Otherwise GROUP is looked up by STATE in
# `state_group_df` and placed right after STATE.
add_state_groups <- function(tab, state_group_df = state_groups) {
  grouped <- if ("GROUP" %in% names(tab)) {
    tab
  } else {
    tab |>
      dplyr::left_join(state_group_df, by = "STATE") |>
      dplyr::relocate(GROUP, .after = STATE)
  }
  grouped |>
    dplyr::mutate(GROUP = factor(GROUP, levels = c("Union", "Border", "Confederate"))) |>
    dplyr::arrange(GROUP, STATE)
}

texcrmort_males <- lapply(X = texcrmort_males, FUN = add_state_groups)
texcnmort_males <- lapply(X = texcnmort_males, FUN = add_state_groups)

# Export excess deaths estimates ------------------------------------------

# One workbook per measure; each list element (baseline) becomes a sheet.
writexl::write_xlsx(x = texcrmort_males, path = "EST_excdeaths_rate_bystate.xlsx")
writexl::write_xlsx(x = texcnmort_males, path = "EST_excdeaths_count_bystate.xlsx")

# END SCRIPT HERE
